package wang.chunfan.code.websocket.server;

import com.alibaba.fastjson.JSON;
import org.springframework.stereotype.Component;
import wang.chunfan.code.websocket.model.Message;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * WebSocket endpoint bound to {@code /webSocket/{sid}}.
 *
 * <p>Each instance holds the {@link Session} and sid of one client connection.
 * All open endpoints are kept in a static registry keyed by sid so that a
 * message can be pushed to one client or broadcast to every client. The
 * number of registered clients is tracked in a static counter and pushed to
 * all clients whenever a connection is opened or closed.
 *
 * @author wangchunfan
 */
@ServerEndpoint("/webSocket/{sid}")
@Component
public class WebSocketServer {

    // Number of clients currently registered in webSocketServers.
    private static final AtomicInteger onlineCount = new AtomicInteger(0);
    // Registry of live endpoints, keyed by client sid.
    private static final ConcurrentHashMap<String, WebSocketServer> webSocketServers = new ConcurrentHashMap<>();
    // Session of the client served by this endpoint instance.
    private Session session;
    // sid of the client served by this endpoint instance.
    private String sid;

    /**
     * Called when a client connects: registers this endpoint under its sid
     * and broadcasts the updated online count to all clients.
     *
     * @param session the newly opened session
     * @param sid     the client id taken from the request path
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("sid") String sid) {
        this.session = session;
        this.sid = sid;
        // Only count the client when the sid was not registered before;
        // a reconnect with the same sid replaces the entry and must not
        // inflate the counter.
        WebSocketServer previous = webSocketServers.put(sid, this);
        if (previous == null) {
            onlineCount.incrementAndGet();
        }
        System.out.println("连接成功：" + this.sid);

        broadcastOnlineCount();
    }

    /**
     * Called when a text message arrives from the client.
     *
     * <p>The payload is parsed as a {@link Message} and dispatched on its type:
     * {@code "single"} echoes the raw payload back to the sender,
     * {@code "group"} broadcasts the raw payload to every registered client,
     * and {@code "heartbeat"} (or any other type) is ignored.
     *
     * @param session the session the message arrived on
     * @param message the raw JSON text sent by the client
     */
    @OnMessage
    public void onMessage(Session session, String message) {
        System.out.println("收到消息：" + this.sid + ";message:" + message);
        Message msg = JSON.parseObject(message, Message.class);

        // An empty payload or a payload without a type cannot be dispatched;
        // switching on a null String would throw NullPointerException.
        if (msg == null || msg.getType() == null) {
            return;
        }

        switch (msg.getType()) {
            // Reply to the sending client only.
            case "single":
                sendMessage(message);
                break;
            // Forward the message to every registered client.
            case "group":
                sendMessage(message, null);
                break;
            // Heartbeat: nothing to do.
            case "heartbeat":
                break;
            // Unknown types are ignored.
            default:
                break;
        }
    }

    /**
     * Called when the connection reports an error; logs the sid and the stack trace.
     *
     * @param session the session on which the error occurred
     * @param error   the error that was raised
     */
    @OnError
    public void onError(Session session, Throwable error) {
        System.out.println("连接错误：" + this.sid);
        error.printStackTrace();
    }

    /**
     * Called when the connection is closed: unregisters this endpoint and
     * broadcasts the updated online count to the remaining clients.
     */
    @OnClose
    public void onClose() {
        // The registry is keyed by sid, so remove by sid (not by instance).
        // The two-argument remove only succeeds while this instance is still
        // the registered one, so a newer connection that reused the sid is
        // left in place and the counter stays consistent with the registry.
        if (this.sid != null && webSocketServers.remove(this.sid, this)) {
            onlineCount.decrementAndGet();
        }

        System.out.println("连接断开:" + this.sid);
        broadcastOnlineCount();
    }

    /**
     * Returns the number of clients currently online.
     *
     * @return the current value of the online counter
     */
    public static int getOnlineCount() {
        return onlineCount.get();
    }

    /**
     * Sends a text message to the client served by this endpoint.
     *
     * <p>Nothing is sent when this endpoint has no session yet or the session
     * is no longer open. Failures while sending are logged and not propagated.
     *
     * @param message the text to send
     */
    public void sendMessage(String message) {
        Session target = this.session;
        if (target == null || !target.isOpen()) {
            return;
        }
        try {
            // Serialize sends on the same session so concurrent broadcasts
            // do not interleave calls on its basic remote.
            synchronized (target) {
                target.getBasicRemote().sendText(message);
            }
        } catch (IllegalStateException e) {
            System.out.println("sid:" + this.sid);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Sends a text message to one client or to all clients.
     *
     * @param message the text to send
     * @param sid     the target client id; {@code null} broadcasts to every
     *                registered client. An sid that is not registered is ignored.
     */
    public void sendMessage(String message, String sid) {
        if (sid == null) {
            for (WebSocketServer webSocketServer : webSocketServers.values()) {
                webSocketServer.sendMessage(message);
            }
            return;
        }

        WebSocketServer webSocketServer = webSocketServers.get(sid);
        if (webSocketServer != null) {
            webSocketServer.sendMessage(message);
        }
    }

    // Broadcasts the current online count to every registered client.
    private void broadcastOnlineCount() {
        Message message = new Message();
        message.setType("onlineCount");
        message.setData(String.valueOf(getOnlineCount()));
        sendMessage(JSON.toJSONString(message), null);
    }
}
